syntax = "proto3";

package xrt;

import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Describes which devices the computations of a replicated program are
// assigned to: one ComputationDevice per computation, each of which lists one
// device per replica.
message DeviceAssignment {
  // The devices assigned to a single computation, across all replicas.
  message ComputationDevice {
    // The coordinates of a single device within the device mesh.
    message DeviceMeshCoordinates {
      // The mesh coordinates for the device. Usually (X, Y, Z, Core), in the
      // order in which they are returned in the TopologyProto.
      //  X    = value(0)
      //  Y    = value(1)
      //  Z    = value(2)
      //  Core = value(3)
      repeated int32 value = 1;
    }
    // One entry per replica of the replicated computation.
    repeated DeviceMeshCoordinates replica_devices = 1;
  }
  // One ComputationDevice per computation (that is, one per core of a
  // replica).
  repeated ComputationDevice computation_devices = 1;
}

// Options for an XLA compilation.
message XLAComputationConfig {
  // The number of replicas the computation will be run on. If this is
  // default (0) it is interpreted as 1.
  int32 num_replicas = 1;
  // The number of "model-parallel" cores per replica. If this is
  // default (0) it is interpreted as 1.
  int32 num_cores_per_replica = 2;
  // Optional metadata about host sends and recvs.
  tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;

  // The arg/result shapes for the whole computation.
  xla.ProgramShapeProto program_shape = 4;
  // The arg/result shapes for each core of a model-parallel
  // computation. per_core_program_shape is optional for a
  // single-core computation.
  repeated xla.ProgramShapeProto per_core_program_shape = 5;
  // Describes how replicated computation instances should be assigned to
  // devices. There are num_cores_per_replica computations, and each one will
  // be sent to, and executed on, the set of replica devices described in the
  // DeviceAssignment proto.
  DeviceAssignment device_assignment = 6;
  // The debugging options to be passed to the XLA compilation process.
  xla.DebugOptions debug_options = 7;

  // Everything inside Experimental is subject to change and is not subject
  // to API stability guarantees in
  // https://www.tensorflow.org/guide/version_compat.
  message Experimental {
    // Associates an input argument index with a flag saying whether that
    // input's value is updated.
    message UpdateIndexPair {
      // The index of the input argument this entry corresponds to.
      int32 index = 1;
      // Whether the value of that input is updated or not.
      bool updated = 2;
    }

    // stateful_input_indices is only useful when using XRT-compiled
    // programs together with standard TensorFlow TPU execution ops, so should
    // be ignored by most clients.
    //
    // Optionally the client can pass information about which inputs
    // to the computation are updates to "stateful" quantities. Each
    // element of stateful_input_indices includes an index indicating
    // which input argument it corresponds to, and a bool indicating
    // whether the value is updated or not. If the XRT computation is
    // going to be used with a TensorFlow TPU execution op then an
    // input index must be present for each input that will correspond
    // to a resource variable in the execution op, and may not be
    // present for any other input.
    repeated UpdateIndexPair stateful_input_indices = 1;
  }

  // Experimental options. See the Experimental message for the (lack of)
  // stability guarantees.
  Experimental experimental = 8;
}

// Options and XLA computation for a compilation.
message XLAComputation {
  // The options for the compilation.
  XLAComputationConfig config = 1;
  // The XLA computation to compile, carried as an HloSnapshot.
  xla.HloSnapshot hlo_snapshot = 2;
}

// Literal to allocate space for, and transfer to, device memory.
message XLAAllocation {
  // Field number 1 is reserved and must not be reused.
  // NOTE(review): presumably held a field that has since been removed --
  // confirm against the file history.
  reserved 1;
  // The literal to allocate space for and transfer.
  xla.LiteralProto value = 2;
}

// Node in a tree describing a tuple constructed from input handles. A
// node is an internal node if tuples is non-empty, in which case
// input_index and release_input_handle are ignored. Otherwise a node
// is a leaf node. Each leaf XLATupleNode is the index of an input
// which corresponds to a handle that will be grafted onto the output
// tuple at that location. If release_input_handle is true that input
// handle will be released and become invalid.  Inputs may be repeated
// in which case leaves of the output tuple will alias. If an input is
// repeated, release_input_handle must be false for every leaf where
// that input appears.
//
// For example, if input 0 has shape {} and input 1 has shape {2,3}
// then the XLATupleNode with structure {1,{0,1}} corresponds to a
// tuple with shape {{2,3},{{},{2,3}}}.
message XLATupleNode {
  // Leaf nodes only: the index of the input whose handle is grafted onto the
  // output tuple at this location. Ignored when tuples is non-empty.
  int32 input_index = 1;
  // Leaf nodes only: if true, the input handle is released and becomes
  // invalid. Must be false wherever the same input appears in more than one
  // leaf. Ignored when tuples is non-empty.
  bool release_input_handle = 2;
  // The child nodes. A non-empty list makes this node an internal node.
  repeated XLATupleNode tuples = 3;
}

// The configuration common to XRT execute operations. It is embedded as
// common_config in both XRTExecutionConfig and XRTChainedExecuteConfig.
message CommonExecutionConfig {
  // The replica index this execute is driving.
  int32 replica_id = 1;
  // Mapping local device ordinals to global replica IDs.
  // local_replica_mapping[LOCAL_DEVICE_ORDINAL] = GLOBAL_REPLICA_ID
  repeated int32 local_replica_mapping = 2;
  // The execution run ID used to correlate different XRT execute operations
  // happening in parallel from different threads.
  int64 run_id = 3;
}

// Options for an XLA execution.
message XRTExecutionConfig {
  // Local device to run on. This is present because the execute Op
  // may be placed on a device such as CPU or TPU_SYSTEM that
  // logically manages multiple cores.
  int32 device_ordinal = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only
  // needed if multiple host send/recvs may be outstanding
  // concurrently with executions.
  string execution_instance_key = 3;
  // If non-zero, rng_seed to reset the core with. A value of zero (the
  // default) means no seed is supplied.
  uint32 rng_seed = 4;
  // If true, release allocation handles on the inputs after running.
  bool release_input_handles = 5;
  // If true, release the handle to the computation after running.
  bool release_compilation_handle = 6;
  // If set to true, and the result shape is a tuple, then instead of returning
  // a single tuple allocation the execution will return a vector of
  // allocations, one for each of the first-level elements of the result tuple.
  bool return_exploded_tuple = 7;
  // Field number 8 is reserved and must not be reused.
  // NOTE(review): presumably held a field that has since been removed --
  // confirm against the file history.
  reserved 8;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 9;
}

// Options for a chained XRT execution.
// NOTE(review): presumably the configuration consumed by the
// XRTChainedExecute operation alongside an XRTChainedExecutePlan -- confirm.
message XRTChainedExecuteConfig {
  // If non-zero, rng_seed to reset the core with. A value of zero (the
  // default) means no seed is supplied.
  uint32 rng_seed = 1;
  // Which model-parallel computation to run from the compiled bundle.
  int32 core_index_in_replica = 2;
  // Optional key to disambiguate between executions. This is only needed if
  // multiple host send/recvs may be outstanding concurrently with executions.
  string execution_instance_key = 3;
  // Field number 4 is reserved and must not be reused.
  // NOTE(review): presumably held a field that has since been removed --
  // confirm against the file history.
  reserved 4;
  // The common configuration for XRT execute operations.
  CommonExecutionConfig common_config = 5;
}

// A single chained execute operation. An operation can either be a device data
// load, or an existing (as in, previously compiled and accessible via its int64
// handle) XLA computation execution.
message XRTChainedExecuteOp {
  // Represents an input for this operation.
  message Input {
    // The index within the XRTChainedExecutePlan.ops post-order of the source
    // operation for this input.
    int64 op_index = 1;
    // The output index of the value generated by the operation at op_index.
    // Zero (default value) means no index ({}) while if an indexing is
    // required, output_index needs to be set to index+1.
    // The +1 offset exists because a proto3 scalar field set to zero is
    // indistinguishable from an unset one, so zero cannot also denote
    // "element 0".
    int64 output_index = 2;
  }
  // Represents an output of the XRTChainedExecute operation, which should
  // originate from the output of this operation.
  message Output {
    // The index in the value generated by this operation, which should be
    // forwarded as XRTChainedExecute output. If output_index is zero (default
    // value) the whole output will be used as result. This means that if the
    // output shape is a tuple, the result will be the full tuple. Otherwise the
    // real sub-tuple index will be output_index - 1.
    int64 output_index = 1;
    // The index in the vector of the results returned by the XRTChainedExecute
    // operation, where this output should be forwarded.
    int64 result_index = 2;
  }

  // What this operation does: either load existing device data or execute an
  // existing compiled computation. At most one of the two handles is set.
  oneof op_oneof {
    // The handle to an existing XRT device data.
    int64 data_handle = 1;
    // The handle to an existing XRT compiled computation.
    int64 computation_handle = 2;
  }
  // The outputs of this XRTChainedExecuteOp operation.
  repeated Output outputs = 3;
  // The inputs of this XRTChainedExecuteOp operation. If data_handle is set,
  // there are no inputs.
  repeated Input inputs = 4;
}

// Execution plan for the XRTChainedExecute operation.
message XRTChainedExecutePlan {
  // The XRT operations to be executed, listed in post-order.
  // XRTChainedExecuteOp.Input.op_index refers to positions in this list.
  repeated XRTChainedExecuteOp ops = 1;
}

// The message used to encode the options for the XRTMetricsCollect operation.
message XRTMetricsCollect {
  // A list of regular expressions to match against the metric names. An empty
  // list means to return all the metrics reported by the collection registry.
  repeated string metrics_regex = 1;
}

// Summary statistics and percentile points computed over the samples held in
// a metric's samples buffer.
message Percentiles {
  // A single percentile point: a percentile and its associated value.
  message Point {
    // In the [0, 100] range.
    double percentile = 1;
    // The value associated with this percentile.
    double value = 2;
  }

  // The time (in nanoseconds) of the first sample within the samples buffer.
  uint64 start_nstime = 1;
  // The time (in nanoseconds) of the last sample within the samples buffer.
  uint64 end_nstime = 2;
  // The minimum value of the samples within the samples buffer.
  double min_value = 3;
  // The maximum value of the samples within the samples buffer.
  double max_value = 4;
  // The mean value of the samples within the samples buffer.
  double mean = 5;
  // The standard deviation of the samples within the samples buffer.
  double stddev = 6;
  // The number of samples within the samples buffer.
  uint64 num_samples = 7;
  // The total number of times a value has been posted to this metric.
  uint64 total_samples = 8;
  // The sum of all the posted values.
  double accumulator = 9;
  // The percentile points reported by the metric.
  repeated Point points = 10;
}

// The value reported for a single named metric, together with its unit of
// measure.
message MetricValues {
  // The unit in which the metric value is expressed. INVALID is the zero
  // value and is therefore what an unset unit_of_measure field reads as.
  enum UnitOfMeasure {
    INVALID = 0;
    NUMBER = 1;
    TIME = 2;
    BYTES = 3;
  }

  // The metric name.
  string name = 1;

  // The metric value: either percentile statistics or a single int64.
  oneof values_oneof {
    Percentiles percentiles_value = 2;
    int64 int64_value = 3;
  }

  // The unit of measure of the value held in values_oneof.
  UnitOfMeasure unit_of_measure = 4;
}

// A collection of metric values.
// NOTE(review): presumably the result returned by the XRTMetricsCollect
// operation -- confirm against the op implementation.
message MetricsReport {
  // The reported metrics, one entry per metric.
  repeated MetricValues metrics = 1;
}

// Total and free memory of a device, expressed in KB.
message MemoryInfo {
  // The total memory on a device, in KB.
  int64 kb_total = 1;
  // The free memory on a device, in KB.
  int64 kb_free = 2;
}
